from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import defaultdict
import argparse
import cv2  # NOQA (Must import before importing caffe2 due to bug in cv2)
import glob
import logging
import os
import sys
import time

from caffe2.python import workspace

from detectron.core.config import assert_and_infer_cfg
from detectron.core.config import cfg
from detectron.core.config import merge_cfg_from_file
from detectron.utils.io import cache_url
from detectron.utils.logging import setup_logging
from detectron.utils.timer import Timer
import detectron.core.test_engine as infer_engine
import detectron.datasets.dummy_datasets as dummy_datasets
import detectron.utils.c2 as c2_utils
import detectron.utils.vis as vis_utils

# Runs at import time, before any model is created in main().
# NOTE(review): presumably registers Detectron's custom operators with
# Caffe2 -- confirm against detectron.utils.c2.
c2_utils.import_detectron_ops()

# OpenCL may be enabled by default in OpenCV3; disable it because it's not
# thread safe and causes unwanted GPU memory allocations.
cv2.ocl.setUseOpenCL(False)


def parse_args():
    """Parse the command line of the inference script.

    Registers every supported option on an ``argparse`` parser and parses
    ``sys.argv``. When the script is started without any argument, the help
    text is printed and the process exits with status 1.

    Returns:
        The ``argparse.Namespace`` with attributes ``cfg``, ``weights``,
        ``output_dir``, ``image_ext``, ``out_when_no_box``, ``output_ext``,
        ``thresh``, ``kp_thresh`` and ``im_or_folder``.
    """
    cli = argparse.ArgumentParser(description='End-to-end inference')

    # One (flag, add_argument keyword arguments) entry per option. The order
    # is significant: it is the order in which options show up in --help.
    option_table = [
        ('--cfg', dict(
            dest='cfg',
            help='cfg model file (/path/to/model_config.yaml)',
            default=None,
            type=str,
        )),
        ('--wts', dict(
            dest='weights',
            help='weights model file (/path/to/model_weights.pkl)',
            default=None,
            type=str,
        )),
        ('--output-dir', dict(
            dest='output_dir',
            help='directory for visualization pdfs (default: /tmp/infer_simple)',
            default='/tmp/infer_simple',
            type=str,
        )),
        ('--image-ext', dict(
            dest='image_ext',
            help='image file name extension (default: jpg)',
            default='jpg',
            type=str,
        )),
        ('--always-out', dict(
            dest='out_when_no_box',
            help='output image even when no object is found',
            action='store_true',
        )),
        ('--output-ext', dict(
            dest='output_ext',
            help='output image file format (default: pdf)',
            default='pdf',
            type=str,
        )),
        ('--thresh', dict(
            dest='thresh',
            help='Threshold for visualizing detections',
            default=0.7,
            type=float,
        )),
        ('--kp-thresh', dict(
            dest='kp_thresh',
            help='Threshold for visualizing keypoints',
            default=2.0,
            type=float,
        )),
        # Positional argument: the image file or the folder to process.
        ('im_or_folder', dict(
            help='image or folder of images',
            default=None,
        )),
    ]
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    # A bare invocation cannot do anything useful: show usage and stop.
    invoked_without_arguments = len(sys.argv) == 1
    if invoked_without_arguments:
        cli.print_help()
        sys.exit(1)

    return cli.parse_args()


import pycocotools.mask as mask_util
import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd
# Never summarize arrays with "..." when they are printed: an infinite
# threshold makes numpy always print every element.
np.set_printoptions(threshold=np.inf)

class Airbus_Submit(object):
    """Accumulate instance masks as run-length-encoded rows and write a CSV.

    Every exported mask becomes one row holding the image name, the RLE
    string, the detection confidence and the mask area in pixels. Rows pile
    up across calls to ``extract_result`` / ``mask_to_rle_csv`` and are
    written out by ``create_csv``.
    """

    def __init__(self, thresh=0.4, csv_file_name='rle.csv'):
        """Create an empty collector.

        Args:
            thresh: minimum confidence a detection needs to be exported.
            csv_file_name: path of the CSV file written by ``create_csv``.
        """
        self.thresh = thresh
        self.csv_file_name = csv_file_name

        # Parallel lists: entry i of each list describes the same mask.
        self.csv_img = []   # image name
        self.csv_rle = []   # RLE string
        self.csv_con = []   # confidence
        self.csv_area = []  # mask area in pixels

    def extract_result(self, cls_boxes, cls_segms, cls_keyps, im_real_name, confidence):
        """Decode the masks of one image and append them to the pending rows.

        Nothing is recorded when there are no boxes, when no box reaches
        ``self.thresh``, or when there are no segmentations.

        Args:
            cls_boxes: per-class list of box arrays (column 4 is the score),
                or an already flattened box array.
            cls_segms: per-class list of encoded masks (flattened list when
                ``cls_boxes`` is already flat), or None.
            cls_keyps: per-class list of keypoints, or None. Not used for the
                export itself.
            im_real_name: image name stored in the ``ImageId`` column.
            confidence: one score per decoded mask.
                NOTE(review): must be in the same order as the flattened
                ``cls_segms``; that only matches a single class's scores when
                no other class produced detections -- confirm with the caller.
        """
        if isinstance(cls_boxes, list):
            boxes, segms, _, _ = self.convert_from_cls_format(
                cls_boxes, cls_segms, cls_keyps
            )
        else:
            # Already flat: previously this path failed on an unbound name.
            boxes, segms = cls_boxes, cls_segms

        if boxes is None or boxes.shape[0] == 0 or boxes[:, 4].max() < self.thresh:
            return
        # Without segmentations there is nothing to encode. The original
        # fell through to an unbound ``masks`` variable here.
        if segms is None or len(segms) == 0:
            return

        masks = np.array(mask_util.decode(segms))
        self.mask_to_rle_csv(im_real_name, masks, confidence)

    def convert_from_cls_format(self, cls_boxes, cls_segms, cls_keyps):
        """Flatten the per-class boxes/segms/keyps produced by the testing code.

        Returns:
            A tuple ``(boxes, segms, keyps, classes)``: ``boxes`` is the
            concatenation of all non-empty per-class box arrays (None when
            every class is empty), ``segms`` and ``keyps`` are flat lists (or
            None when the input is None), and ``classes`` holds the class
            index of every box.
        """
        box_list = [b for b in cls_boxes if len(b) > 0]
        boxes = np.concatenate(box_list) if box_list else None

        segms = None
        if cls_segms is not None:
            segms = [s for slist in cls_segms for s in slist]

        keyps = None
        if cls_keyps is not None:
            keyps = [k for klist in cls_keyps for k in klist]

        classes = []
        for class_index, class_boxes in enumerate(cls_boxes):
            classes += [class_index] * len(class_boxes)
        return boxes, segms, keyps, classes

    def rle_encode(self, img):
        """Run-length encode a binary 2-D mask.

        The mask is read column by column (hence the transpose) and the
        result is a space separated sequence of ``start length`` pairs with
        1-based start positions. An all-zero mask yields an empty string.
        """
        pixels = img.T.flatten()    # T is needed here.
        # Pad with a zero on both sides so runs touching the borders still
        # produce both a start and an end transition.
        pixels = np.concatenate([[0], pixels, [0]])
        runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
        # Turn every (start, end) pair into (start, length).
        runs[1::2] -= runs[::2]
        return ' '.join(str(x) for x in runs)

    def mask_to_rle_csv(self, img, masks, confidence):
        """Append one row per sufficiently confident, non-empty mask.

        Masks are handled from highest to lowest confidence and pixels
        already claimed by a more confident mask are removed, so exported
        masks never overlap.

        Args:
            img: image name stored with every row.
            masks: array indexed as ``masks[:, :, i]`` for mask ``i``;
                non-zero pixels belong to the mask.
            confidence: one score per mask, indexed like the last axis of
                ``masks``.
        """
        confidence = np.asarray(confidence)
        order = np.argsort(-confidence)     # high --> low confidence
        # Pixels already taken by exported masks. Sized from the masks
        # themselves instead of a fixed 768x768 canvas.
        occupied = np.zeros(masks.shape[:2], dtype=bool)
        for i in order:
            if confidence[i] < self.thresh:
                continue
            # Keep only the part of the mask not covered by earlier masks.
            visible = (masks[:, :, i] != 0) & ~occupied
            area = visible.sum()
            if area == 0:
                continue
            # Echo the confidence of every exported mask to stdout.
            print(confidence[i])
            rle = self.rle_encode(visible)
            occupied |= visible

            self.csv_img.append(img)
            self.csv_rle.append(rle)
            self.csv_con.append(confidence[i])
            self.csv_area.append(area)

    def create_csv(self):
        """Write all accumulated rows to ``self.csv_file_name``.

        Columns are written in the order ImageId, EncodedPixels, confidence,
        area, without the DataFrame index.
        """
        df = pd.DataFrame(
            {
                'ImageId': self.csv_img,
                'EncodedPixels': self.csv_rle,
                'confidence': self.csv_con,
                'area': self.csv_area,
            },
            columns=['ImageId', 'EncodedPixels', 'confidence', 'area'],
        )
        # str(',') keeps the separator a native string even though the file
        # enables unicode_literals.
        df.to_csv(self.csv_file_name, index=False, sep=str(','))
        print("%s is written successfully."%self.csv_file_name)


def main(args):
    """Run inference on an image or a folder and export the masks as a CSV.

    Loads the model described by ``args.cfg`` / ``args.weights``, runs
    detection on every input image and collects the resulting masks in an
    ``Airbus_Submit`` instance, which finally writes ``../rle.csv``.

    Args:
        args: namespace as returned by ``parse_args``. Reads ``cfg``,
            ``weights``, ``im_or_folder``, ``image_ext``, ``output_dir``,
            ``output_ext`` and ``thresh``; ``weights`` is overwritten with
            the value returned by ``cache_url``.

    Raises:
        AssertionError: if the config describes an RPN-only model or a model
            that needs precomputed proposals.
    """
    logger = logging.getLogger(__name__)

    merge_cfg_from_file(args.cfg)
    # NOTE(review): inference below only ever runs inside NamedCudaScope(0);
    # confirm that two GPUs are really required here.
    cfg.NUM_GPUS = 2
    args.weights = cache_url(args.weights, cfg.DOWNLOAD_CACHE)
    assert_and_infer_cfg(cache_urls=False)

    assert not cfg.MODEL.RPN_ONLY, \
        'RPN models are not supported'
    assert not cfg.TEST.PRECOMPUTED_PROPOSALS, \
        'Models that require precomputed proposals are not supported'

    model = infer_engine.initialize_model_from_cfg(args.weights)
    if os.path.isdir(args.im_or_folder):
        # A real list (not a lazy iterator) so the total is known up front.
        im_list = glob.glob(args.im_or_folder + '/*.' + args.image_ext)
    else:
        im_list = [args.im_or_folder]
    total = len(im_list)

    # Honor --thresh; its default (0.7) equals the previously fixed value.
    airbus = Airbus_Submit(
        thresh=getattr(args, 'thresh', 0.7), csv_file_name='../rle.csv'
    )

    for i, im_name in enumerate(im_list):
        out_name = os.path.join(
            args.output_dir, '{}'.format(os.path.basename(im_name) + '.' + args.output_ext)
        )
        logger.info('Processing {} -> {}'.format(im_name, out_name))
        im = cv2.imread(im_name)
        if im is None:
            # cv2.imread signals an unreadable file by returning None.
            logger.warning('Could not read image {}; skipping'.format(im_name))
            continue
        timers = defaultdict(Timer)
        t = time.time()
        with c2_utils.NamedCudaScope(0):
            cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
                model, im, None, timers=timers
            )
        logger.info('Inference time: {:.3f}s'.format(time.time() - t))
        for k, v in timers.items():
            logger.info(' | {}: {:.3f}s'.format(k, v.average_time))
        if i == 0:
            logger.info(
                ' \\ Note: inference on the first image will be slower than the '
                'rest (caches and auto-tuning need to warm up)'
            )

        # Progress against the real number of inputs.
        print('{}/{}'.format(i + 1, total))

        # Only the scores of class index 1 are passed on as confidences.
        # NOTE(review): presumably the model has a single foreground class --
        # confirm, otherwise the scores do not line up with the masks.
        airbus.extract_result(
            cls_boxes, cls_segms, cls_keyps,
            im_real_name=os.path.basename(im_name),
            confidence=cls_boxes[1][:, 4],
        )
    airbus.create_csv()


# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':
    # Initialize Caffe2 with log level 0 before anything else touches it.
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    # NOTE(review): presumably configures the logger used by main() --
    # confirm against detectron.utils.logging.
    setup_logging(__name__)
    args = parse_args()
    main(args)
